#!/usr/bin/python3

import gc
import threading
import time
import weakref
from pathlib import Path
from sqlite3 import OperationalError
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Final,
    Iterator,
    List,
    Optional,
    Tuple,
    Union,
    final,
)

from eth_typing import BlockNumber, ChecksumAddress, HexAddress, HexStr
from web3.datastructures import AttributeDict
from web3.types import BlockData

import brownie.network.rpc as rpc
from brownie._c_constants import sha1
from brownie._config import CONFIG, _get_data_folder
from brownie._singleton import _Singleton
from brownie.convert import Wei
from brownie.exceptions import BrownieEnvironmentError, CompilerError
from brownie.project.build import DEPLOYMENT_KEYS
from brownie.typing import ContractBuildJson, ContractName, Count, PCMap, ProgramCounter
from brownie.utils import bytes_to_hexstring
from brownie.utils.sql import Cursor

from .transaction import TransactionReceipt
from .web3 import _resolve_address, web3

if TYPE_CHECKING:
    from .contract import Contract, ProjectContract

# maps a source key to its (sha1 hash, source path) pair as stored in the db
PathMap = Dict[str, Tuple[HexStr, str]]
# (build artifact json, {path: source}) pair returned by `_get_deployment`
Deployment = Tuple[ContractBuildJson, Dict[str, Any]]

AnyContract = Union["Contract", "ProjectContract"]

# all currently known contract instances, keyed by checksummed address
_contract_map: Final[Dict[ChecksumAddress, AnyContract]] = {}
# weak references to objects that must be notified on RPC resets/reverts
_revert_refs: Final[List[weakref.ReferenceType]] = []

# persistent sqlite store for live-network deployment artifacts
cur: Final = Cursor(_get_data_folder().joinpath("deployments.db"))
cur.execute("CREATE TABLE IF NOT EXISTS sources (hash PRIMARY KEY, source)")


@final
class TxHistory(metaclass=_Singleton):
    """List-like singleton container that contains TransactionReceipt objects.
    Whenever a transaction is broadcast, the TransactionReceipt is automatically
    added to this container."""

    def __init__(self) -> None:
        # every broadcast transaction, in order of submission
        self._list: List[TransactionReceipt] = []
        # running per-function gas statistics, maintained by `_gas`
        self.gas_profile: Final[Dict[str, Dict[str, int]]] = {}
        # receive `_revert` / `_reset` callbacks when the local chain rewinds
        _revert_register(self)

    def __repr__(self) -> str:
        # in the console, show the underlying list for readability
        if CONFIG.argv["cli"] == "console":
            return str(self._list)
        return super().__repr__()

    def __getattribute__(self, name: str) -> Any:
        # filter dropped transactions prior to attribute access
        # NOTE: status == -2 marks a dropped tx; `super().__getattribute__`
        # is used for `_list` to avoid recursing back into this method
        items: List[TransactionReceipt] = super().__getattribute__("_list")
        items = [i for i in items if i.status != -2]
        setattr(self, "_list", items)
        return super().__getattribute__(name)

    def __bool__(self) -> bool:
        return bool(self._list)

    def __contains__(self, item: Any) -> bool:
        return item in self._list

    def __iter__(self) -> Iterator[TransactionReceipt]:
        return iter(self._list)

    def __getitem__(self, key: int) -> TransactionReceipt:
        return self._list[key]

    def __len__(self) -> int:
        return len(self._list)

    def _reset(self) -> None:
        # callback from `_notify_registry` when the chain resets to genesis
        self._list.clear()

    def _revert(self, height: BlockNumber) -> None:
        # callback from `_notify_registry`: drop txs mined above `height`
        self._list = [i for i in self._list if i.block_number <= height]  # type: ignore [operator]

    def _add_tx(self, tx: TransactionReceipt) -> None:
        # record a newly broadcast transaction exactly once
        if tx not in self._list:
            self._list.append(tx)

    def clear(self, only_confirmed: bool = False) -> None:
        """
        Clear the list.

        Arguments
        ---------
        only_confirmed : bool, optional
            If True, transactions which are still marked as pending will not be removed.
        """
        if only_confirmed:
            # status -1 marks a pending transaction
            self._list = [i for i in self._list if i.status == -1]
        else:
            self._list.clear()

    def copy(self) -> List[TransactionReceipt]:
        """Returns a shallow copy of the object as a list"""
        return self._list.copy()

    def filter(self, key: Optional[Callable] = None, **kwargs: Any) -> List[TransactionReceipt]:
        """
        Return a filtered list of transactions.

        Arguments
        ---------
        key : Callable, optional
            An optional function to filter with. It should expect one argument and return
            True or False.

        Keyword Arguments
        -----------------
        **kwargs : Any
            Names and expected values for TransactionReceipt attributes.

        Returns
        -------
        List
            A filtered list of TransactionReceipt objects.
        """
        # attribute filters are applied first, then the optional key callable
        result = (i for i in self._list if all(getattr(i, k) == v for k, v in kwargs.items()))
        return list(result if key is None else filter(key, result))

    def wait(self, key: Optional[Callable] = None, **kwargs: Any) -> None:
        """
        Wait for pending transactions to confirm.

        This method iterates over a list of transactions generated by `TxHistory.filter`,
        waiting until each transaction has confirmed. If no arguments are given, all
        transactions within the container are used.

        Arguments
        ---------
        key : Callable, optional
            An optional function to filter with. It should expect one argument and return
            True or False.

        Keyword Arguments
        -----------------
        **kwargs : Any
            Names and expected values for TransactionReceipt attributes.
        """
        # re-filter each iteration so txs that drop or confirm are re-evaluated
        while True:
            pending = next(iter(self.filter(key, status=-1, **kwargs)), None)
            if pending is None:
                return
            pending._confirmed.wait()

    def from_sender(self, account: str) -> List[TransactionReceipt]:
        """Returns a list of transactions where the sender is account"""
        return [i for i in self._list if i.sender == account]

    def to_receiver(self, account: str) -> List[TransactionReceipt]:
        """Returns a list of transactions where the receiver is account"""
        return [i for i in self._list if i.receiver == account]

    def of_address(self, account: str) -> List[TransactionReceipt]:
        """Returns a list of transactions where account is the sender or receiver"""
        return [i for i in self._list if i.receiver == account or i.sender == account]

    def _gas(self, fn_name: str, gas_used: int, is_success: bool) -> None:
        # fold one observation into the running gas stats for `fn_name`
        gas = self.gas_profile.setdefault(fn_name, {})
        if not gas:
            # first observation for this function
            gas.update(
                avg=gas_used, high=gas_used, low=gas_used, count=1, count_success=0, avg_success=0
            )
            if is_success:
                gas["count_success"] = 1
                gas["avg_success"] = gas_used
            return
        # incremental running average plus high/low watermarks
        gas.update(
            avg=(gas["avg"] * gas["count"] + gas_used) // (gas["count"] + 1),
            high=max(gas["high"], gas_used),
            low=min(gas["low"], gas_used),
        )
        gas["count"] += 1
        if is_success:
            # `count` is captured before the increment so the average uses
            # the number of prior successes
            count = gas["count_success"]
            gas["count_success"] += 1
            if not gas["avg_success"]:
                gas["avg_success"] = gas_used
            else:
                avg = gas["avg_success"]
                gas["avg_success"] = (avg * count + gas_used) // (count + 1)


@final
class Chain(metaclass=_Singleton):
    """
    List-like singleton used to access block data, and perform actions such as
    snapshotting, mining, and chain rewinds.
    """

    def __init__(self) -> None:
        # offset (in seconds) between local time and the RPC's clock
        self._time_offset: int = 0
        # snapshot id taken via `snapshot()`, target for `revert()`
        self._snapshot_id: Optional[int | str] = None
        # snapshot id taken on connect, target for `reset()`
        self._reset_id: Optional[int | str] = None
        # most recent snapshot id, advanced after each tx / mine / sleep
        self._current_id: Optional[int | str] = None
        self._undo_lock: Final = threading.Lock()
        # (snapshot id, fn, args, kwargs) per confirmed tx, consumed by `undo`
        self._undo_buffer: Final[List[Tuple[int | str, Any, Tuple[Any, ...], Dict[str, Any]]]] = []
        # (fn, args, kwargs) of undone transactions, consumed by `redo`
        self._redo_buffer: Final[List[Tuple[Any, Tuple[Any, ...], Dict[str, Any]]]] = []
        # cached chain id, lazily populated by the `id` property
        self._chainid: Optional[int] = None
        # cached gas limit and the block timestamp it was sampled at
        self._block_gas_time: int = -1
        self._block_gas_limit: int = 0

    def __repr__(self) -> str:
        try:
            return f"<Chain object (chainid={self.id}, height={self.height})>"
        except Exception:
            # any RPC failure is treated as "not connected"
            return "<Chain object (disconnected)>"

    def __len__(self) -> int:
        """
        Return the current number of blocks.
        """
        # +1 because block numbering starts at the genesis block (0)
        return web3.eth.block_number + 1

    def __getitem__(self, block_number: BlockNumber) -> BlockData | AttributeDict:
        """
        Return information about a block by block number.

        Arguments
        ---------
        block_number : BlockNumber
            Integer of a block number. If the value is negative, the block returned
            is relative to the most recently mined block, e.g. `chain[-1]` returns
            the most recent block.

        Returns
        -------
        BlockData
            web3 block data object
        """
        if not isinstance(block_number, int):
            raise TypeError("Block height must be given as an integer")
        if block_number < 0:
            # translate a negative index to an absolute height
            block_number = web3.eth.block_number + 1 + block_number
        block: BlockData | AttributeDict = web3.eth.get_block(block_number)
        # opportunistically refresh the cached gas limit from the newest block seen
        if block["timestamp"] > self._block_gas_time:
            self._block_gas_limit = block["gasLimit"]
            self._block_gas_time = block["timestamp"]
        return block

    def __iter__(self) -> Iterator[BlockData | AttributeDict]:
        # iterate every block from genesis to the current height
        get_block = web3.eth.get_block
        for i in range(web3.eth.block_number + 1):
            block: BlockData | AttributeDict = get_block(i)
            yield block

    def new_blocks(
        self, height_buffer: int = 0, poll_interval: int = 5
    ) -> Iterator[BlockData | AttributeDict]:
        """
        Generator for iterating over new blocks.

        Arguments
        ---------
        height_buffer : int, optional
            The number of blocks behind "latest" to return. A higher value means
            more delayed results but less likelihood of uncles.
        poll_interval : int, optional
            Maximum interval between querying for a new block, if the height has
            not changed. Set this lower to detect uncles more frequently.
        """
        if height_buffer < 0:
            raise ValueError("Buffer cannot be negative")

        last_block: Optional[BlockData | AttributeDict] = None
        last_height = 0
        last_poll = 0.0

        get_block = web3.eth.get_block
        while True:
            # re-query when the height changes, or at most every `poll_interval`
            # seconds so uncled blocks at the same height are still detected
            if last_poll + poll_interval < time.time() or last_height != web3.eth.block_number:
                last_height = web3.eth.block_number
                block: BlockData | AttributeDict = get_block(last_height - height_buffer)
                last_poll = time.time()

                if block != last_block:
                    last_block = block
                    yield last_block
            else:
                time.sleep(1)

    @property
    def height(self) -> BlockNumber:
        """Current block height."""
        return web3.eth.block_number

    @property
    def id(self) -> int:
        """Chain id of the connected network (cached after the first query)."""
        if self._chainid is None:
            self._chainid = web3.eth.chain_id
        return self._chainid

    @property
    def block_gas_limit(self) -> Wei:
        """Gas limit of the latest block, cached for up to one hour."""
        if time.time() > self._block_gas_time + 3600:
            block: BlockData | AttributeDict = web3.eth.get_block("latest")
            self._block_gas_limit = block["gasLimit"]
            self._block_gas_time = block["timestamp"]
        return Wei(self._block_gas_limit)

    @property
    def base_fee(self) -> Wei:
        """Base fee of the latest block."""
        block: BlockData | AttributeDict = web3.eth.get_block("latest")
        return Wei(block["baseFeePerGas"])

    @property
    def priority_fee(self) -> Wei:
        """Current max priority fee suggested by the node."""
        return Wei(web3.eth.max_priority_fee)

    def _revert(self, id_: int | str) -> int | str:
        """Revert the RPC to snapshot `id_` and return a fresh snapshot id.

        Snapshots are single-use on most clients, so a new one is taken
        immediately after reverting. Registered objects are notified via
        `_notify_registry`.
        """
        rpc_client = rpc.Rpc()
        if web3.isConnected() and not web3.eth.block_number and not self._time_offset:
            # still at genesis with no time offset - nothing to revert
            _notify_registry(0)  # type: ignore [arg-type]
            return rpc_client.snapshot()
        rpc_client.revert(id_)
        id_ = rpc_client.snapshot()
        try:
            # re-sync the local time offset with the reverted chain state
            self.sleep(0)
        except NotImplementedError:
            pass
        _notify_registry()
        return id_

    def _add_to_undo_buffer(
        self, tx: TransactionReceipt, fn: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any]
    ) -> None:
        """Record a confirmed transaction so it can later be undone/redone."""
        with self._undo_lock:
            tx._confirmed.wait()
            # store the snapshot taken *before* this tx, so undoing restores
            # the pre-transaction state
            self._undo_buffer.append((self._current_id, fn, args, kwargs))  # type: ignore [arg-type]
            redo_buffer = self._redo_buffer
            if redo_buffer and (fn, args, kwargs) == redo_buffer[-1]:
                # this tx is a redo of the most recently undone action
                redo_buffer.pop()
            else:
                # any new action invalidates the redo history
                redo_buffer.clear()
            self._current_id = rpc.Rpc().snapshot()
            # ensure the local time offset is correct, in case it was modified by the transaction
            self.sleep(0)

    def _network_connected(self) -> None:
        # callback invoked when a network connection is established
        self._reset_id = None
        try:
            self.reset()
        except NotImplementedError:
            # required for geth dev
            _notify_registry(0)  # type: ignore [arg-type]

    def _network_disconnected(self) -> None:
        # callback invoked on disconnect - drop all chain-specific state
        self._undo_buffer.clear()
        self._redo_buffer.clear()
        self._snapshot_id = None
        self._reset_id = None
        self._current_id = None
        self._chainid = None
        _notify_registry(0)  # type: ignore [arg-type]

    def get_transaction(self, txid: Union[str, bytes]) -> TransactionReceipt:
        """
        Return a TransactionReceipt object for the given transaction hash.
        """
        if not isinstance(txid, str):
            txid = bytes_to_hexstring(txid)
        # prefer the existing receipt from the history, if one exists
        tx = next((i for i in TxHistory() if i.txid == txid), None)
        return tx or TransactionReceipt(txid, silent=True, required_confs=0)

    def time(self) -> int:
        """Return the current epoch time from the test RPC as an int"""
        return int(time.time() + self._time_offset)

    def sleep(self, seconds: int) -> None:
        """
        Increase the time within the test RPC.

        Arguments
        ---------
        seconds : int
            Number of seconds to increase the time by
        """
        if not isinstance(seconds, int):
            raise TypeError("seconds must be an integer value")
        self._time_offset = int(rpc.Rpc().sleep(seconds))

        if seconds:
            # a nonzero sleep changes chain state, invalidating the redo buffer
            self._redo_buffer.clear()
            self._current_id = rpc.Rpc().snapshot()

    def mine(
        self, blocks: int = 1, timestamp: Optional[int] = None, timedelta: Optional[int] = None
    ) -> BlockNumber:
        """
        Increase the block height within the test RPC.

        Arguments
        ---------
        blocks : int
            Number of new blocks to be mined
        timestamp : int
            Timestamp of the final block being mined. If multiple blocks
            are mined, they will be placed at equal intervals starting
            at `chain.time()` and ending at `timestamp`.
        timedelta : int
            Timedelta for the final block to be mined. If given, the final
            block will have a timestamp of `chain.time() + timedelta`

        Returns
        -------
        BlockNumber
            Current block height
        """
        if not isinstance(blocks, int):
            raise TypeError("`blocks` must be an integer value")

        if timedelta is not None:
            if timestamp is not None:
                raise ValueError("Cannot use both `timestamp` and `timedelta`")

            timestamp = self.time() + timedelta

        if timestamp is None:
            # no timestamps given - mine each block with default timing
            params: List = [[] for _ in range(blocks)]
        elif blocks == 1:
            params = [[timestamp]]
        else:
            # space block timestamps evenly between now and `timestamp`
            now = self.time()
            duration = (timestamp - now) / (blocks - 1)
            params = [[round(now + duration * i)] for i in range(blocks)]

        for i in range(blocks):
            rpc.Rpc().mine(*params[i])

        if timestamp is not None:
            # re-sync the local time offset after explicitly setting timestamps
            self.sleep(0)

        self._redo_buffer.clear()
        self._current_id = rpc.Rpc().snapshot()
        return web3.eth.block_number

    def snapshot(self) -> None:
        """
        Take a snapshot of the current state of the EVM.

        This action clears the undo buffer.
        """
        self._undo_buffer.clear()
        self._redo_buffer.clear()
        self._snapshot_id = self._current_id = rpc.Rpc().snapshot()

    def revert(self) -> BlockNumber:
        """
        Revert the EVM to the most recently taken snapshot.

        This action clears the undo buffer.

        Returns
        -------
        BlockNumber
            Current block height
        """
        if self._snapshot_id is None:
            raise ValueError("No snapshot set")
        self._undo_buffer.clear()
        self._redo_buffer.clear()
        # `_revert` returns a fresh snapshot id; keep it as the new target
        self._snapshot_id = self._current_id = self._revert(self._snapshot_id)
        return web3.eth.block_number

    def reset(self) -> BlockNumber:
        """
        Revert the EVM to the initial state when loaded.

        This action clears the undo buffer.

        Returns
        -------
        BlockNumber
            Current block height
        """
        self._snapshot_id = None
        self._undo_buffer.clear()
        self._redo_buffer.clear()
        if self._reset_id is None:
            # first reset on this connection - capture the baseline state
            self._reset_id = self._current_id = rpc.Rpc().snapshot()
            _notify_registry(0)  # type: ignore [arg-type]
        else:
            self._reset_id = self._current_id = self._revert(self._reset_id)
        return web3.eth.block_number

    def undo(self, num: int = 1) -> BlockNumber:
        """
        Undo one or more transactions.

        Arguments
        ---------
        num : int, optional
            Number of transactions to undo.

        Returns
        -------
        BlockNumber
            Current block height
        """
        with self._undo_lock:
            if num < 1:
                raise ValueError("num must be greater than zero")
            if not self._undo_buffer:
                raise ValueError("Undo buffer is empty")
            if num > len(self._undo_buffer):
                raise ValueError(f"Undo buffer contains {len(self._undo_buffer)} items")

            for _ in range(num):
                id_, fn, args, kwargs = self._undo_buffer.pop()
                self._redo_buffer.append((fn, args, kwargs))

            # `id_` from the final pop is the snapshot taken before the
            # oldest transaction being undone
            self._current_id = self._revert(id_)
            return web3.eth.block_number

    def redo(self, num: int = 1) -> BlockNumber:
        """
        Redo one or more undone transactions.

        Arguments
        ---------
        num : int, optional
            Number of transactions to redo.

        Returns
        -------
        BlockNumber
            Current block height
        """
        with self._undo_lock:
            if num < 1:
                raise ValueError("num must be greater than zero")
            if not self._redo_buffer:
                raise ValueError("Redo buffer is empty")
            if num > len(self._redo_buffer):
                raise ValueError(f"Redo buffer contains {len(self._redo_buffer)} items")

            # re-broadcasting repopulates the undo buffer via `_add_to_undo_buffer`
            for _ in range(num):
                fn, args, kwargs = self._redo_buffer.pop()
                fn(*args, **kwargs)

            return web3.eth.block_number


# objects that will update whenever the RPC is reset or reverted must register
# by calling to this function. The must also include _revert and _reset methods
# to receive notifications from this object
def _revert_register(obj: object) -> None:
    """Register `obj` for notifications when the RPC is reset or reverted.

    The object is held via weak reference and must implement both a
    `_revert(height)` and a `_reset()` method.
    """
    ref = weakref.ref(obj)
    _revert_refs.append(ref)


def _notify_registry(height: Optional[BlockNumber] = None) -> None:
    """Notify all registered objects of a chain rewind.

    A nonzero `height` triggers each object's `_revert`, zero triggers
    `_reset`. Dead weak references are pruned as they are encountered.
    """
    gc.collect()
    if height is None:
        height = web3.eth.block_number
    for ref in list(_revert_refs):
        target = ref()
        if target is None:
            # referent was garbage collected - drop the stale reference
            _revert_refs.remove(ref)
            continue
        if height:
            target._revert(height)
        else:
            target._reset()


def _find_contract(address: Optional[HexAddress]) -> Optional[AnyContract]:
    """Locate a known contract instance for `address`.

    The in-memory contract map is checked first. On a live network
    (one with a `chainid` in its config), a `Contract` is instantiated
    from persisted data as a fallback. Returns None when no contract
    can be resolved.
    """
    if address is None:
        return None

    checksummed = _resolve_address(address)
    known = _contract_map.get(checksummed)
    if known is not None:
        return known
    # persisted deployments only exist for live networks
    if "chainid" not in CONFIG.active_network:
        return None

    # local import avoids a circular dependency at module load time
    from brownie.network.contract import Contract

    try:
        return Contract(checksummed)
    except (ValueError, CompilerError):
        return None


def _get_current_dependencies() -> List[ContractName]:
    """Return a sorted list of the names of every currently known contract
    together with the names of all contracts they depend on."""
    names: set = set()
    for item in _contract_map.values():
        names.add(item._name)
        names.update(item._build.get("dependencies", []))
    return sorted(names)


def _add_contract(contract: AnyContract) -> None:
    # record the contract in the module-level map so it can be resolved
    # by address via `_find_contract`
    _contract_map[contract.address] = contract


def _remove_contract(contract: AnyContract) -> None:
    # drop the contract from the module-level map; a no-op if it was
    # never registered
    _contract_map.pop(contract.address, None)


def _get_deployment(
    address: Optional[HexAddress] = None,
    alias: Optional[ContractName] = None,
) -> Deployment | Tuple[None, None]:
    """
    Fetch a stored deployment from the deployments database.

    Arguments
    ---------
    address : HexAddress, optional
        Address of the deployed contract. Mutually exclusive with `alias`.
    alias : ContractName, optional
        User-assigned alias for the deployment. Mutually exclusive with `address`.

    Returns
    -------
    Deployment | Tuple[None, None]
        A `(build_json, sources)` pair, or `(None, None)` when no matching
        deployment is stored.

    Raises
    ------
    ValueError
        If both, or neither, of `address` and `alias` are given.
    BrownieEnvironmentError
        When connected to a local (chainid-less) network.
    """
    if address and alias:
        raise ValueError("Passed both params address and alias, should be only one!")
    if address:
        address = _resolve_address(address)
        # parameterized to avoid interpolating values into the SQL statement
        query = "address=?"
        params: Tuple[str, ...] = (address,)
    elif alias:
        query = "alias=?"
        params = (alias,)
    else:
        # previously this fell through with `query` unbound, raising an
        # opaque UnboundLocalError at the SELECT below
        raise ValueError("Must provide either an address or an alias")

    try:
        name = f"chain{CONFIG.active_network['chainid']}"
    except KeyError:
        raise BrownieEnvironmentError("Functionality not available in local environment") from None
    try:
        row = cur.fetchone(f"SELECT * FROM {name} WHERE {query}", params)
    except OperationalError:
        # the table for this chain does not exist yet
        row = None
    if not row:
        return None, None

    keys = ("address", "alias", "paths") + DEPLOYMENT_KEYS
    build_json: ContractBuildJson = dict(zip(keys, row))  # type: ignore [assignment]
    # when we json.dump the path map, the tuples are encoded as lists so we need to make them tuples again.
    path_map: PathMap = {k: tuple(v) for k, v in build_json.pop("paths", {}).items()}  # type: ignore [typeddict-item]
    # resolve each source hash back into the original source text
    sources: Dict[str, Any] = {
        i[1]: cur.fetchone("SELECT source FROM sources WHERE hash=?", (i[0],))[0]  # type: ignore [index]
        for i in path_map.values()
    }

    build_json["allSourcePaths"] = {k: path_map[k][1] for k in path_map}
    # pcMap keys were stringified by json encoding - restore integer keys
    pc_map: Optional[Dict[int | str, ProgramCounter]] = build_json.get("pcMap")  # type: ignore [assignment]
    if isinstance(pc_map, dict):
        build_json["pcMap"] = PCMap({Count(int(k)): pc_map[k] for k in pc_map})

    return build_json, sources


def _add_deployment(
    contract: "Contract",
    alias: Optional[ContractName] = None,
) -> None:
    """Persist a deployed contract's artifacts into the deployments database.

    A no-op on development networks (those without a `chainid` in the
    active network config). ABI-only artifacts never overwrite previously
    stored full compilation artifacts.
    """
    active = CONFIG.active_network
    if "chainid" not in active:
        return

    address = _resolve_address(contract.address)
    name = f"chain{active['chainid']}"

    cur.execute(
        f"CREATE TABLE IF NOT EXISTS {name} "
        f"(address UNIQUE, alias UNIQUE, paths, {', '.join(DEPLOYMENT_KEYS)})"
    )

    build_json = contract._build
    if "compiler" not in build_json:
        # do not replace full contract artifacts with ABI-only ones
        existing = cur.fetchone(f"SELECT compiler FROM {name} WHERE address=?", (address,))
        if existing and existing[0]:
            return

    # store each source file keyed by its sha1, recording (hash, path) pairs
    path_map: dict = {}
    for key, path in build_json.get("allSourcePaths", {}).items():
        source = contract._sources.get(path)
        if source is None:
            source = Path(path).read_text()
        digest = sha1(source.encode()).hexdigest()
        cur.insert("sources", digest, source)
        path_map[key] = [digest, path]

    cur.insert(name, address, alias, path_map, *(build_json.get(i) for i in DEPLOYMENT_KEYS))


def _remove_deployment(
    address: Optional[HexAddress] = None,
    alias: Optional[ContractName] = None,
) -> Deployment | Tuple[None, None]:
    """
    Delete a stored deployment from the deployments database.

    Arguments
    ---------
    address : HexAddress, optional
        Address of the deployed contract. Mutually exclusive with `alias`.
    alias : ContractName, optional
        User-assigned alias for the deployment. Mutually exclusive with `address`.

    Returns
    -------
    Deployment | Tuple[None, None]
        The deleted deployment's `(build_json, sources)` pair, or
        `(None, None)` if no matching deployment was stored.

    Raises
    ------
    ValueError
        If both, or neither, of `address` and `alias` are given.
    BrownieEnvironmentError
        When connected to a local (chainid-less) network.
    """
    if address and alias:
        raise ValueError("Passed both params address and alias, should be only one!")
    if address:
        address = _resolve_address(address)
        query = f"address='{address}'"
    elif alias:
        # NOTE(review): `alias` is interpolated directly into SQL - safe only
        # if aliases are validated upstream; consider parameterizing
        query = f"alias='{alias}'"
    else:
        # previously this fell through with `query` unbound, raising an
        # opaque UnboundLocalError at the DELETE below
        raise ValueError("Must provide either an address or an alias")

    try:
        name = f"chain{CONFIG.active_network['chainid']}"
    except KeyError:
        raise BrownieEnvironmentError("Functionality not available in local environment") from None

    # fetch the deployment first so it can be returned to the caller
    deployment = _get_deployment(address, alias)
    # delete entry from chain{n}
    cur.execute(f"DELETE FROM {name} WHERE {query}")
    # delete all entries from sources matching the contract's source hashes
    # NOTE(review): a source shared by another deployment is also removed here -
    # confirm this is intended
    if contract := _find_contract(address):
        for key, path in contract._build.get("allSourcePaths", {}).items():
            source = contract._sources.get(path)
            if source is None:
                source = Path(path).read_text()
            hash_ = sha1(source.encode()).hexdigest()
            cur.execute(f"DELETE FROM sources WHERE hash='{hash_}'")

    return deployment
